Fix TorchAO v1 group offloading with use_stream=True #14112
I ran
diff --git a/tests/models/testing_utils/quantization.py b/tests/models/testing_utils/quantization.py
index 95b6b0fc6..ea81538c3 100644
--- a/tests/models/testing_utils/quantization.py
+++ b/tests/models/testing_utils/quantization.py
@@ -1377,7 +1377,7 @@ class TorchAoCompileTesterMixin(TorchAoConfigMixin, QuantizationCompileTesterMix
     @pytest.mark.parametrize("quant_type", ["int8wo"], ids=["int8wo"])
     def test_torchao_torch_compile_with_group_offload(self, quant_type):
-        self._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type])
+        self._test_torch_compile_with_group_offload(TorchAoConfigMixin.TORCHAO_QUANT_TYPES[quant_type], use_stream=True)
     @is_gguf
and it passed. What am I missing?
Thanks @sayakpaul, you are not missing anything. Your test uses the TorchAO v2 path, and that path passes because the resulting I narrowed this PR to the legacy The follow-up patch no longer skips pinning for all TorchAO tensors. It keeps the normal pinned-memory path and only falls back to the CPU copy when a TorchAO tensor raises
It is not recommended to use v1. If things work fine with v2, then I don't think any fixes (like the ones introduced in this PR) are needed at all.
Thanks @sayakpaul, that makes sense. Since this is now scoped to the deprecated v1 path and the current v2/Float8 paths work, I'm going to close this PR.
What does this PR do?
Refs #13281.
This PR is scoped to the legacy TorchAO int8 weight-only path that produces AffineQuantizedTensor (Int8WeightOnlyConfig(version=1)). The current Float8WeightOnlyConfig path reported in #13281, and the int8 compile path that uses TorchAO version=2, both support is_pinned() and pin_memory() on current main.

The streamed group-offload path keeps a CPU copy of each tensor and normally pins that copy before transferring a group back to the accelerator. AffineQuantizedTensor is still a TorchAO tensor subclass, so _to_cpu() must call tensor.cpu(), but its pinning ops raise NotImplementedError: ... aten.is_pinned.

This PR keeps pinned memory for tensors whose pinning ops work, including TorchAO v2 tensors, and falls back to the CPU copy only when a TorchAO tensor does not implement those pinning ops.
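The fallback described above can be illustrated with a minimal sketch. This is not the code in the PR: the function name `to_cpu_maybe_pinned` and both stand-in classes are hypothetical, chosen so the example runs without torch, TorchAO, or a GPU. The sketch also simplifies by catching `NotImplementedError` from any tensor, whereas the PR limits the fallback to TorchAO tensors.

```python
class PlainCpuTensor:
    """Hypothetical stand-in for a CPU tensor whose pinning ops work."""

    def __init__(self, pinned=False):
        self._pinned = pinned

    def cpu(self):
        # Already on the CPU in this sketch, so return self.
        return self

    def is_pinned(self):
        return self._pinned

    def pin_memory(self):
        # Return a new, pinned copy.
        return PlainCpuTensor(pinned=True)


class NoPinTensor(PlainCpuTensor):
    """Hypothetical stand-in for a tensor subclass without pinning ops."""

    def is_pinned(self):
        raise NotImplementedError("aten.is_pinned is not implemented")

    def pin_memory(self):
        raise NotImplementedError("aten.pin_memory is not implemented")


def to_cpu_maybe_pinned(tensor):
    """Return (cpu_copy, pinned).

    Try the normal pinned-memory path first. If the pinning ops are
    not implemented, keep the plain CPU copy instead of failing.
    """
    cpu_copy = tensor.cpu()
    try:
        if not cpu_copy.is_pinned():
            cpu_copy = cpu_copy.pin_memory()
        return cpu_copy, True
    except NotImplementedError:
        # Fallback: an unpinned CPU copy is still valid, only slower to transfer.
        return cpu_copy, False


copy_a, pinned_a = to_cpu_maybe_pinned(PlainCpuTensor())
copy_b, pinned_b = to_cpu_maybe_pinned(NoPinTensor())
print(pinned_a, pinned_b)
```

The point of trying the pinning ops first, rather than skipping them by tensor type, is that tensors that do support pinning keep the faster transfer path.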
End-to-end reproduction
Environment: NVIDIA RTX 4090, torch==2.8.0+cu128, torchao==0.17.0.

The script uses the public tiny Flux pipeline, quantizes the transformer with Int8WeightOnlyConfig(version=1), enables pipe.transformer.enable_group_offload(..., use_stream=True), moves the remaining modules to CUDA, and runs pipe(...).

Before:
After:
Additional checks:
Before submitting
.ai/review-rules.md?
Who can review?
cc @sayakpaul